package com.troila.xfspark.websocket;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.troila.xfspark.model.SparkProperties;
import com.troila.xfspark.model.SparkRequest;
import com.troila.xfspark.model.SparkRoleContent;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

@Slf4j
public class WebSocketMessageHandler extends AbstractWebSocketHandler {

    private SparkProperties properties;

    private ObjectMapper objectMapper;

    private WebSocketHandler sparkWebSocketHandler;

    @Autowired
    public void setProperties(SparkProperties properties) {
        this.properties = properties;
    }

    @Autowired
    public void setObjectMapper(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    @Autowired
    @Qualifier("sparkWebSocketHandler")
    public void setSparkWebSocketHandler(WebSocketHandler sparkWebSocketHandler) {
        this.sparkWebSocketHandler = sparkWebSocketHandler;
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        super.afterConnectionEstablished(session);
        WebSocketSessionManager.put(session.getId(), session);
        logger.info("[local websocket] 连接成功，session = {}", session.getId());
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        super.afterConnectionClosed(session, status);
        WebSocketSessionManager.remove(session.getId());
        logger.info("[local websocket] 退出连接");
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        super.handleTextMessage(session, message);
        logger.info("[local websocket] 收到消息：{}", message.getPayload());

        SparkRoleContent content = parseObject(message.getPayload(), new TypeReference<SparkRoleContent>() {
        });
        if (Objects.isNull(WebSocketSessionManager.getSparkSession())) {
            WebSocketClient client = new StandardWebSocketClient();
            client.doHandshake(sparkWebSocketHandler, SparkWrapper.getInstance().websocketUrl(properties))
                    .addCallback(new ListenableFutureCallback<>() {
                        @Override
                        public void onSuccess(WebSocketSession result) {
                            logger.info("[spark websocket] 连接成功2");
                            try {
                                sendToSpark(result, content);
                            } catch (IOException e) {
                                logger.error(e.getMessage());
                            }
                        }

                        @Override
                        public void onFailure(Throwable ex) {
                            logger.info("[spark websocket] 发生错误：{}", ex.getMessage());
                        }
                    });
        } else {
            sendToSpark(WebSocketSessionManager.getSparkSession(), content);
        }
    }

    private void sendToSpark(WebSocketSession session, SparkRoleContent content) throws IOException {
        if (Objects.nonNull(content)) {
            SparkRequest request = SparkRequest.builder()
                    .header(SparkRequest.SparkHeader.builder()
                            .appId(properties.getAppId())
                            .build())
                    .parameter(SparkRequest.SparkParameter.builder()
                            .chat(SparkRequest.SparkChat.builder()
                                    .chatId(session.getId())
                                    .build())
                            .build())
                    .payload(SparkRequest.SparkPayload.builder()
                            .message(SparkRequest.SparkMessage.builder()
                                    .text(List.of(content))
                                    .build())
                            .build())
                    .build();
            String message = objectMapper.writeValueAsString(request);
            logger.info("[spark websocket] 发送讯飞消息：{}", message);
            session.sendMessage(new TextMessage(message));
        }
    }

    private <T> T parseObject(String json, TypeReference<T> valueTypeRef) {
        try {
            return objectMapper.readValue(json, valueTypeRef);
        } catch (JsonProcessingException e) {
            logger.error(e.getMessage());
            return null;
        }
    }


}
